"""
Phase 5: Bilateral Aggregation Bridge
=======================================
Side-by-side comparison of Z coefficients from:
  - Bilateral aggregated positions (summed CPIS/CDIS by reporter-year, GDP-normalized)
  - Multilateral IIP positions (IMF balance sheet)
  - Current account (CA/GDP) on the intersection sample

Table 8: Bilateral vs multilateral Z₁ comparison
Table 8b: Sign/significance concordance summary
Table 8c: Partner-count heterogeneity
Table 8d: Bilateral as noise filter (excl financial centers)
"""

import pandas as pd
import numpy as np
from pathlib import Path
import sys
import warnings
warnings.filterwarnings('ignore')

PROJECT_DIR = Path(__file__).resolve().parent.parent
ROOT_DIR = PROJECT_DIR.parent
sys.path.insert(0, str(ROOT_DIR / "multilateral" / "src"))
from model import PanelGLS

# Processed input panels are read from here; Markdown tables are written to
# OUT_TABLES, which is created eagerly at import time.
DATA = PROJECT_DIR.joinpath("data", "processed")
OUT_TABLES = PROJECT_DIR.joinpath("output", "tables")
OUT_TABLES.mkdir(parents=True, exist_ok=True)

# Regressors echoed to the console by run_gls; main() appends whichever
# EBA_CONTROLS have more than 200 non-missing observations.
DEMO_VARS = ['Z_1', 'Z_2', 'Z_3']
EBA_CONTROLS = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'log_rel_opw', 'kaopen']


def stars(p):
    """Return the conventional significance marker for p-value ``p``.

    ``'***'`` when p < 0.01, ``'**'`` when p < 0.05, ``'*'`` when p < 0.1,
    otherwise ``''``. Comparisons are strict, so p == 0.05 yields ``'*'``;
    a NaN p-value fails every comparison and yields ``''``.
    """
    for cutoff, marker in ((0.01, '***'), (0.05, '**'), (0.1, '*')):
        if p < cutoff:
            return marker
    return ''


def fmt(val, se, p):
    """Format one coefficient for a Markdown results table.

    Returns a ``(coef_text, se_text)`` pair: the coefficient to four
    decimals followed by its significance stars (from ``stars(p)``), and
    the standard error to four decimals wrapped in parentheses.
    """
    marker = stars(p)
    coef_text = f"{val:.4f}{marker}"
    se_text = f"({se:.4f})"
    return coef_text, se_text


def run_gls(df, y_var, x_vars, label, min_obs=50):
    """Fit a PanelGLS regression of ``y_var`` on ``x_vars`` and collect results.

    Rows with a missing value in the dependent variable, any regressor, or
    the ``iso3``/``year`` panel identifiers are dropped before fitting
    (complete-case sample).

    Parameters
    ----------
    df : pandas.DataFrame
        Panel data; must contain ``iso3`` and ``year`` columns.
    y_var : str
        Dependent-variable column.
    x_vars : sequence of str
        Regressor columns; coefficients are read positionally from
        ``gls.beta`` / ``gls.se`` / ``gls.pvalues`` in this order.
    label : str
        Model name used in console output and as the table column header.
    min_obs : int, default 50
        Minimum number of complete-case rows required to attempt a fit.

    Returns
    -------
    dict or None
        Keys ``model``, ``dep_var``, ``n_obs``, ``n_countries``,
        ``r_squared``, ``rho`` (as reported by PanelGLS), plus
        ``<var>_coef`` / ``<var>_se`` / ``<var>_p`` for every regressor.
        ``None`` when a required column is absent or fewer than ``min_obs``
        complete rows remain; a SKIP line is printed in either case.
    """
    # Accept any sequence; pandas would treat a tuple as a single column key.
    x_vars = list(x_vars)
    cols = [y_var] + x_vars + ['iso3', 'year']

    # Skip instead of raising KeyError so one absent series does not abort
    # the whole phase; callers already treat None as "no result".
    missing = [c for c in cols if c not in df.columns]
    if missing:
        print(f"  SKIP {label}: missing columns {missing}")
        return None

    sub = df[cols].dropna()
    if len(sub) < min_obs:
        print(f"  SKIP {label}: only {len(sub)} obs")
        return None

    gls = PanelGLS()
    gls.fit(sub[y_var].values, sub[x_vars].values,
            sub['iso3'].values, sub['year'].values)

    result = {
        'model': label, 'dep_var': y_var,
        'n_obs': gls.n_obs, 'n_countries': gls.n_countries,
        'r_squared': gls.r_squared, 'rho': gls.rho,
    }
    for i, name in enumerate(x_vars):
        result[f'{name}_coef'] = gls.beta[i]
        result[f'{name}_se'] = gls.se[i]
        result[f'{name}_p'] = gls.pvalues[i]

    # Console summary lists only the regressors of interest (DEMO_VARS).
    print(f"\n  {label} (N={gls.n_obs}, R²={gls.r_squared:.4f})")
    for i, name in enumerate(x_vars):
        if name in DEMO_VARS:
            sig = stars(gls.pvalues[i])
            print(f"    {name:20s} {gls.beta[i]:10.4f} ({gls.se[i]:.4f}) {sig}")

    return result


def write_table(results, filename, title, key_vars=None, out_dir=None):
    """Render regression results as a Markdown table and save it to disk.

    One column per result dict, headed by its ``'model'`` label. Each
    variable in ``key_vars`` gets a coefficient row (four decimals plus
    significance stars) and a standard-error row; models without that
    variable get blank cells. Summary rows for the dependent variable, N,
    R² and country count follow, then a footnote.

    Parameters
    ----------
    results : list of dict
        Result dicts as produced by ``run_gls``. Nothing is written when the
        list is empty.
    filename : str
        Output file name.
    title : str
        Level-1 heading placed above the table.
    key_vars : list of str, optional
        Variables to report, in order. Defaults to every ``<var>_coef`` key
        found across ``results``, in first-seen order.
    out_dir : path-like, optional
        Target directory (created if needed). Defaults to ``OUT_TABLES``.
    """
    if not results:
        return

    if key_vars is None:
        key_vars = []
        for r in results:
            for k in r:
                if k.endswith('_coef'):
                    # Strip only the trailing suffix; str.replace would also
                    # mangle a variable whose name contains '_coef' elsewhere.
                    v = k[:-len('_coef')]
                    if v not in key_vars:
                        key_vars.append(v)

    divider = "|:---|" + "|".join(["---:" for _ in results]) + "|"
    lines = [f"# {title}\n"]
    lines.append("| Variable | " + " | ".join(r['model'] for r in results) + " |")
    lines.append(divider)

    for var in key_vars:
        coef_row = f"| {var} |"
        se_row = "| |"
        for r in results:
            if f'{var}_coef' in r:
                c, s = fmt(r[f'{var}_coef'], r[f'{var}_se'], r[f'{var}_p'])
                coef_row += f" {c} |"
                se_row += f" {s} |"
            else:
                coef_row += " |"
                se_row += " |"
        lines.append(coef_row)
        lines.append(se_row)

    # Repeated delimiter row used as a visual break before summary stats.
    lines.append(divider)
    for stat, key, fmt_str in [('Dep var', 'dep_var', '{}'), ('N', 'n_obs', '{}'),
                                ('R²', 'r_squared', '{:.4f}'),
                                ('Countries', 'n_countries', '{}')]:
        row = f"| {stat} |" + "".join(f" {fmt_str.format(r[key])} |" for r in results)
        lines.append(row)

    lines.append("\n*Panel GLS with AR(1) errors. Standard errors in parentheses.*")
    lines.append("*\\*p<0.1, \\*\\*p<0.05, \\*\\*\\*p<0.01*")

    if out_dir is None:
        directory = OUT_TABLES
    else:
        directory = Path(out_dir)
        directory.mkdir(parents=True, exist_ok=True)
    path = directory / filename
    # Explicit UTF-8: titles and rows contain non-ASCII text ('R²', '—'),
    # which would otherwise depend on the platform's locale encoding.
    path.write_text('\n'.join(lines), encoding='utf-8')
    print(f"\n  Saved: {path}")


def _section(title, width=50):
    """Print a console section banner: blank line, rule, title, rule."""
    rule = "=" * width
    print("\n" + rule)
    print(title)
    print(rule)


def _concordance_lines(pairs, results):
    """Build the Markdown lines of the Z₁ sign/significance concordance table.

    For each ``(bil_var, iip_var, bil_label, iip_label)`` pair whose two
    regressions both appear in ``results`` (matched on ``'dep_var'``) and
    both report a ``Z_1`` coefficient, one row is added with each Z_1
    estimate plus stars, whether the signs agree (product > 0), and whether
    both p-values are below 0.1. Each row is also echoed to the console.
    Pairs missing either regression are omitted.
    """
    # Index by dependent variable; later entries win, matching a linear scan.
    by_dep = {r['dep_var']: r for r in results}

    lines = ["# Table 8b: Bilateral vs Multilateral Z₁ Concordance\n"]
    lines.append("| Measure | Bilateral | IIP Equivalent | Same sign? | Both sig? |")
    lines.append("|:---|---:|---:|:---:|:---:|")

    for bil_var, iip_var, bil_label, iip_label in pairs:
        r_bil = by_dep.get(bil_var)
        r_iip = by_dep.get(iip_var)
        if not (r_bil and r_iip and 'Z_1_coef' in r_bil and 'Z_1_coef' in r_iip):
            continue

        b_bil, p_bil = r_bil['Z_1_coef'], r_bil['Z_1_p']
        b_iip, p_iip = r_iip['Z_1_coef'], r_iip['Z_1_p']
        same_sign = 'Yes' if (b_bil * b_iip > 0) else 'No'
        both_sig = 'Yes' if (p_bil < 0.1 and p_iip < 0.1) else 'No'

        lines.append(f"| {bil_label} vs {iip_label} | "
                     f"{b_bil:.2f}{stars(p_bil)} | {b_iip:.2f}{stars(p_iip)} | "
                     f"{same_sign} | {both_sig} |")
        print(f"  {bil_label:20s} Z₁={b_bil:8.2f}{stars(p_bil):3s}  vs  "
              f"{iip_label:15s} Z₁={b_iip:8.2f}{stars(p_iip):3s}  "
              f"sign={'=' if same_sign=='Yes' else '≠'}  sig={'Y' if both_sig=='Yes' else 'N'}")

    lines.append("\n*Bilateral positions aggregated from CPIS/CDIS by reporter-year, normalized by reporter GDP.*")
    lines.append("*IIP positions from IMF International Investment Position.*")
    lines.append("*Intersection sample restricted to country-years with both data sources.*")
    return lines


def main():
    """Run the Phase 5 bilateral aggregation bridge and write Tables 8–8d.

    Reads ``net_gross_panel.csv`` from ``DATA`` (years <= 2024), then writes
    to ``OUT_TABLES``:

    * Table 8  (``bilateral_bridge.md``): GLS of each bilateral aggregate and
      its IIP counterpart, plus CA/GDP, on the intersection sample.
    * Table 8b (``bilateral_bridge_concordance.md``): Z₁ sign/significance
      concordance between each bilateral/IIP pair.
    * Table 8c (``bilateral_bridge_partners.md``): the comparison split at
      the median partner count, when ``n_partners`` has >100 observations.
    * Table 8d (``bilateral_bridge_noise_filter.md``): full intersection vs.
      intersection excluding a fixed list of financial centres.

    Returns early without writing anything when no bilateral aggregate has
    more than 50 non-missing observations, or when ``gross_assets_gdp`` (the
    IIP series anchoring the intersection sample) is absent.
    """
    print("=" * 70)
    print("PHASE 5: BILATERAL AGGREGATION BRIDGE")
    print("=" * 70)

    df = pd.read_csv(DATA / "net_gross_panel.csv")
    df = df[df['year'] <= 2024].copy()
    print(f"Panel: {len(df)} obs, {df['iso3'].nunique()} countries")

    # Keep only controls with >200 non-missing values so sparse series do
    # not shrink every complete-case GLS sample.
    controls = [c for c in EBA_CONTROLS if c in df.columns and df[c].notna().sum() > 200]
    base_vars = DEMO_VARS + controls

    # ── TABLE 8: Bilateral vs Multilateral Comparison ──────────────────
    _section("TABLE 8: BILATERAL VS MULTILATERAL COMPARISON")

    # Check for properly GDP-normalized bilateral vars
    bilateral_vars = ['agg_portfolio_total_gdp', 'agg_portfolio_debt_gdp', 'agg_fdi_outward_gdp']
    bilateral_vars = [v for v in bilateral_vars if v in df.columns and df[v].notna().sum() > 50]

    if not bilateral_vars:
        print("  No bilateral aggregated variables found. Skipping.")
        print("\n" + "=" * 70)
        print("PHASE 5 COMPLETE (no bilateral data)")
        print("=" * 70)
        return

    # The intersection mask below is anchored on IIP gross assets; without
    # the column it would raise KeyError.
    if 'gross_assets_gdp' not in df.columns:
        print("  IIP gross_assets_gdp not found. Skipping.")
        print("\n" + "=" * 70)
        print("PHASE 5 COMPLETE (no IIP data)")
        print("=" * 70)
        return

    for v in bilateral_vars:
        n = df[v].notna().sum()
        print(f"  {v}: {n} obs, mean={df[v].mean():.2f}% of GDP")

    # Intersection sample: country-years with the first available bilateral
    # aggregate AND IIP gross assets.
    mask = df[bilateral_vars[0]].notna() & df['gross_assets_gdp'].notna()
    df_int = df[mask].copy()
    print(f"  Intersection sample: {len(df_int)} obs, {df_int['iso3'].nunique()} countries")

    # Paired comparisons: bilateral aggregated vs IIP equivalent
    pairs = [
        ('agg_portfolio_total_gdp', 'gross_assets_gdp', 'Portfolio total', 'Gross assets'),
        ('agg_portfolio_debt_gdp', 'debt_assets_gdp', 'Portfolio debt', 'Debt assets'),
        ('agg_fdi_outward_gdp', 'fdi_assets_gdp', 'FDI outward', 'FDI assets'),
    ]

    results_t8 = []
    for bil_var, iip_var, bil_label, iip_label in pairs:
        if bil_var not in bilateral_vars or iip_var not in df.columns:
            continue

        r_bil = run_gls(df_int, bil_var, base_vars, f'Bilat: {bil_label}')
        if r_bil:
            results_t8.append(r_bil)

        r_iip = run_gls(df_int, iip_var, base_vars, f'IIP: {iip_label}')
        if r_iip:
            results_t8.append(r_iip)

    # CA on the same sample (guarded: the column is optional in the panel).
    if 'ca_gdp' in df.columns:
        r_ca = run_gls(df_int, 'ca_gdp', base_vars, 'CA/GDP (intersect)')
        if r_ca:
            results_t8.append(r_ca)

    write_table(results_t8, "bilateral_bridge.md",
                "Table 8: Bilateral Aggregated vs Multilateral IIP Positions (% of GDP)",
                key_vars=DEMO_VARS + controls)

    # ── TABLE 8b: Sign/Significance Concordance ────────────────────────
    _section("SIGN/SIGNIFICANCE CONCORDANCE")

    lines = _concordance_lines(pairs, results_t8)
    # Explicit UTF-8: the table contains 'Z₁' (U+2081), which locale
    # encodings such as cp1252 cannot represent.
    (OUT_TABLES / "bilateral_bridge_concordance.md").write_text(
        '\n'.join(lines), encoding='utf-8')
    print("\n  Saved: bilateral_bridge_concordance.md")

    # ── TABLE 8c: Partner-Count Heterogeneity ──────────────────────────
    _section("PARTNER-COUNT HETEROGENEITY")

    if 'n_partners' in df_int.columns and df_int['n_partners'].notna().sum() > 100:
        median_partners = df_int['n_partners'].median()
        print(f"  Median partner count: {median_partners:.0f}")

        # Median split; rows with missing n_partners fall in neither group.
        splits = [
            ('Many partners', df_int['n_partners'] >= median_partners),
            ('Few partners', df_int['n_partners'] < median_partners),
        ]

        results_partners = []
        for label, split_mask in splits:
            sub = df_int[split_mask].copy()
            print(f"  {label}: {len(sub)} obs, {sub['iso3'].nunique()} countries")

            for dep, dep_label in [('agg_portfolio_total_gdp', 'bilat_total'),
                                   ('gross_assets_gdp', 'IIP_assets'),
                                   ('ca_gdp', 'CA')]:
                if dep in sub.columns:
                    r = run_gls(sub, dep, base_vars, f'{label}: {dep_label}')
                    if r:
                        results_partners.append(r)

        if results_partners:
            write_table(results_partners, "bilateral_bridge_partners.md",
                        "Table 8c: Partner-Count Heterogeneity",
                        key_vars=DEMO_VARS)

    # ── TABLE 8d: Bilateral as Noise Filter — Excl Financial Centers ───
    _section("TABLE 8d: BILATERAL AS NOISE FILTER (EXCL FC)")

    FINANCIAL_CENTERS = ['LUX', 'IRL', 'HKG', 'SGP', 'CHE', 'NLD', 'BEL']
    df_int_nofc = df_int[~df_int['iso3'].isin(FINANCIAL_CENTERS)].copy()
    print(f"  Intersection excl FC: {len(df_int_nofc)} obs, "
          f"{df_int_nofc['iso3'].nunique()} countries")

    results_t8d = []

    def _add(sample, dep, label):
        # Fit one regression and keep it only if run_gls produced a result.
        r = run_gls(sample, dep, base_vars, label)
        if r:
            results_t8d.append(r)

    for bil_var, iip_var, bil_label, iip_label in pairs:
        if bil_var not in bilateral_vars or iip_var not in df.columns:
            continue
        # Full intersection
        _add(df_int, bil_var, f'Full: {bil_label}')
        _add(df_int, iip_var, f'Full: {iip_label}')
        # Excl FC intersection
        _add(df_int_nofc, bil_var, f'ExFC: {bil_label}')
        _add(df_int_nofc, iip_var, f'ExFC: {iip_label}')

    # Income balance on both samples
    if 'income_balance_gdp' in df.columns:
        _add(df_int, 'income_balance_gdp', 'Full: income_bal')
        _add(df_int_nofc, 'income_balance_gdp', 'ExFC: income_bal')

    write_table(results_t8d, "bilateral_bridge_noise_filter.md",
                "Table 8d: Bilateral as Noise Filter — Full vs Excl Financial Centers",
                key_vars=DEMO_VARS + controls)

    # Signal-to-noise summary: Z_1 t-statistic for every Table 8d model.
    print("\n  Signal-to-noise summary:")
    for r in results_t8d:
        if 'Z_1_coef' in r and 'Z_1_se' in r and r['Z_1_se'] > 0:
            t_stat = r['Z_1_coef'] / r['Z_1_se']
            print(f"    {r['model']:30s} Z₁={r['Z_1_coef']:8.2f}  "
                  f"t={t_stat:6.2f}  p={r['Z_1_p']:.4f}")

    print("\n" + "=" * 70)
    print("PHASE 5 COMPLETE")
    print("=" * 70)


# Run the Phase 5 pipeline only when executed as a script; importing the
# module defines the helpers without calling main().
if __name__ == '__main__':
    main()
